{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5HEkgJW62Zhq"
      },
      "source": [
        "Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ZvnzHC7lmzWB"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mxPxpHKHMAkl"
      },
      "source": [
        "# Step 7: Improve model accuracy with data augmentation\n",
        "\n",
        "This is the notebook for step 7 of the codelab [**Build a handwritten digit classifier app with TensorFlow Lite**](https://codelabs.developers.google.com/codelabs/digit-classifier-tflite/).\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/lite/codelabs/digit_classifier/ml/step7_improve_accuracy.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003e\n",
        "    Run in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/lite/codelabs/digit_classifier/ml/step7_improve_accuracy.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003e\n",
        "    View source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w0qKcVsBNVyL"
      },
      "source": [
        "In previous steps, we trained a model on the MNIST dataset to recognize handwritten digits and achieved above 98% accuracy on our validation dataset. However, when you deployed the model in an Android app and tested it, you probably noticed an accuracy problem: although the app could recognize digits that you drew, its accuracy was probably well below 98%.\n",
        "\n",
        "In this notebook, we will explore the cause of the accuracy drop and use data augmentation to improve deployment accuracy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p8bO0hupMdZM"
      },
      "source": [
        "## Preparation\n",
        "\n",
        "Let's start by importing TensorFlow and other supporting libraries that are used for data processing and visualization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nImr6z7TMBJQ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rivc81WyRpXG"
      },
      "source": [
        "Import the MNIST dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZGd5hkioMpcr"
      },
      "outputs": [],
      "source": [
        "mnist = keras.datasets.mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
        "\n",
        "# Normalize the input images so that each pixel value is between 0 and 1.\n",
        "train_images = train_images / 255.0\n",
        "test_images = test_images / 255.0\n",
        "\n",
        "# Add a channel dimension to the images in the training and test datasets to\n",
        "# leverage Keras's data augmentation utilities later.\n",
        "train_images = np.expand_dims(train_images, axis=3)\n",
        "test_images = np.expand_dims(test_images, axis=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0SbDbl3HSHh1"
      },
      "source": [
        "Define a utility function so that we can quickly create multiple models with the same architecture for comparison."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JjoUeI5WSE_H"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  model = keras.Sequential([\n",
        "    keras.layers.InputLayer(input_shape=(28, 28, 1)),\n",
        "    keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "    keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu),\n",
        "    keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "    keras.layers.Dropout(0.25),\n",
        "    keras.layers.Flatten(),\n",
        "    keras.layers.Dense(10, activation=tf.nn.softmax)\n",
        "  ])\n",
        "  model.compile(optimizer='adam',\n",
        "                loss='sparse_categorical_crossentropy',\n",
        "                metrics=['accuracy'])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QGgIJ4pYThkm"
      },
      "source": [
        "Confirm that our model can achieve above 98% accuracy on the MNIST dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W2cYWUbkTb8Q"
      },
      "outputs": [],
      "source": [
        "base_model = create_model()\n",
        "base_model.fit(\n",
        "    train_images,\n",
        "    train_labels,\n",
        "    epochs=5,\n",
        "    validation_data=(test_images, test_labels)\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lEtLCGS0Ufag"
      },
      "source": [
        "## Troubleshoot the accuracy drop\n",
        "\n",
        "Let's look at the MNIST digit images again and try to identify the cause of the accuracy drop we experienced in deployment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xptCfTnBTgvm"
      },
      "outputs": [],
      "source": [
        "# Show the first 25 images in the training dataset.\n",
        "plt.figure(figsize=(10,10))\n",
        "for i in range(25):\n",
        "  plt.subplot(5,5,i+1)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  plt.grid(False)\n",
        "  plt.imshow(np.squeeze(train_images[i], axis=2), cmap=plt.cm.gray)\n",
        "  plt.xlabel(train_labels[i])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "phB_hWqFWOQR"
      },
      "source": [
        "We can see from the 25 images above that the digits are about the same size and sit in the center of the images. Let's check whether this holds across the MNIST dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2V0KsvSLVE6I"
      },
      "outputs": [],
      "source": [
        "# A utility function that returns where the digit is in the image.\n",
        "def digit_area(mnist_image):\n",
        "  # Remove the channel axis\n",
        "  mnist_image = np.squeeze(mnist_image, axis=2)\n",
        "\n",
        "  # Extract the list of columns that contain at least 1 pixel from the digit\n",
        "  x_nonzero = np.nonzero(np.amax(mnist_image, 0))\n",
        "  x_min = np.min(x_nonzero)\n",
        "  x_max = np.max(x_nonzero)\n",
        "\n",
        "  # Extract the list of rows that contain at least 1 pixel from the digit\n",
        "  y_nonzero = np.nonzero(np.amax(mnist_image, 1))\n",
        "  y_min = np.min(y_nonzero)\n",
        "  y_max = np.max(y_nonzero)\n",
        "\n",
        "  return [x_min, x_max, y_min, y_max]\n",
        "\n",
        "# Calculate the area containing the digit for every image in the MNIST training set\n",
        "digit_area_rows = []\n",
        "for image in train_images:\n",
        "  digit_area_row = digit_area(image)\n",
        "  digit_area_rows.append(digit_area_row)\n",
        "digit_area_df = pd.DataFrame(\n",
        "  digit_area_rows,\n",
        "  columns=['x_min', 'x_max', 'y_min', 'y_max']\n",
        ")\n",
        "digit_area_df.hist()"
      ]
    },
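    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The bounding-box computation in `digit_area` can be tried on a synthetic image. The following is a minimal sketch, not part of the original codelab: `toy_image`, `bounds` and the 7-pixel shift are hypothetical names and values chosen for illustration. It applies the same `np.amax` and `np.nonzero` steps to a filled rectangle, then shifts the rectangle sideways to show how an off-center drawing moves the bounds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical 28x28 image: a filled rectangle stands in for a digit.\n",
        "toy_image = np.zeros((28, 28))\n",
        "toy_image[6:22, 10:18] = 1.0  # rows 6..21, columns 10..17\n",
        "\n",
        "def bounds(image):\n",
        "  # Columns (x) and rows (y) that contain at least one non-zero pixel.\n",
        "  x_nonzero = np.nonzero(np.amax(image, axis=0))[0]\n",
        "  y_nonzero = np.nonzero(np.amax(image, axis=1))[0]\n",
        "  return [int(x_nonzero.min()), int(x_nonzero.max()),\n",
        "          int(y_nonzero.min()), int(y_nonzero.max())]\n",
        "\n",
        "print(bounds(toy_image))  # [10, 17, 6, 21]\n",
        "\n",
        "# Shift the rectangle 7 pixels to the right (a hypothetical off-center drawing).\n",
        "shifted_image = np.roll(toy_image, 7, axis=1)\n",
        "print(bounds(shifted_image))  # [17, 24, 6, 21]"
      ]
    },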
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GulcBpSRc3CO"
      },
      "source": [
        "From the histograms, you can confirm that the digits in MNIST images fit neatly within a certain area at the center of the images.\n",
        "\n",
        "![Mnist range](http://download.tensorflow.org/models/tflite/digit_classifier/mnist_range.png)\n",
        "\n",
        "However, when you wrote digits in your Android app, you probably did not make sure that each digit fit within the area that digits occupy in the MNIST dataset. The machine learning model has not seen such data before, so it performed poorly, especially when you wrote a digit that was off the center of the drawing pad.\n",
        "\n",
        "Let's add some data augmentation to the MNIST dataset to check whether our assumption is true. We will distort our MNIST dataset by adding:\n",
        "* Rotation\n",
        "* Width and height shift\n",
        "* Shear\n",
        "* Zoom"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AC5l3W7bY1td"
      },
      "outputs": [],
      "source": [
        "# Define data augmentation\n",
        "datagen = keras.preprocessing.image.ImageDataGenerator(\n",
        "  rotation_range=30,\n",
        "  width_shift_range=0.25,\n",
        "  height_shift_range=0.25,\n",
        "  shear_range=0.25,\n",
        "  zoom_range=0.2\n",
        ")\n",
        "\n",
        "# Generate augmented data from MNIST dataset\n",
        "train_generator = datagen.flow(train_images, train_labels)\n",
        "test_generator = datagen.flow(test_images, test_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wy1IfuRZjjec"
      },
      "source": [
        "Let's see what our digit images look like after augmentation. You can see that we now have much more variation in how the digits are placed in the images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1G-tWDc2aia1"
      },
      "outputs": [],
      "source": [
        "augmented_images, augmented_labels = next(train_generator)\n",
        "plt.figure(figsize=(10,10))\n",
        "for i in range(25):\n",
        "    plt.subplot(5,5,i+1)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    plt.imshow(np.squeeze(augmented_images[i], axis=2), cmap=plt.cm.gray)\n",
        "    plt.xlabel('Label: %d' % augmented_labels[i])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G-zys-yNkL4b"
      },
      "source": [
        "Let's evaluate the digit classifier model that we trained earlier on this augmented test dataset and see whether its accuracy drops."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IPnP4xPbjqWi"
      },
      "outputs": [],
      "source": [
        "base_model.evaluate(test_generator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AgNlTPPVlDqX"
      },
      "source": [
        "You can see that accuracy drops significantly, to below 40%, on the augmented test dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xls8oqBnlcZj"
      },
      "source": [
        "## Improve accuracy with data augmentation\n",
        "\n",
        "Now let's train our model on the augmented dataset to make it perform better in deployment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DX4XRgc9k-6s"
      },
      "outputs": [],
      "source": [
        "improved_model = create_model()\n",
        "improved_model.fit(train_generator, epochs=5, validation_data=test_generator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3UlqZ1XnnJRu"
      },
      "source": [
        "We can see that because the model saw distorted digit images during training, its accuracy on distorted test digit images improved significantly, to about 90%."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fsI4mHtbpE_A"
      },
      "source": [
        "## Convert to TensorFlow Lite\n",
        "\n",
        "Let's convert the improved model to TensorFlow Lite and redeploy it to the Android app."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a0l6HcBTmh7-"
      },
      "outputs": [],
      "source": [
        "# Convert Keras model to TF Lite format and quantize.\n",
        "converter = tf.lite.TFLiteConverter.from_keras_model(improved_model)\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_quantized_model = converter.convert()\n",
        "\n",
        "# Save the quantized model to a file in the current working directory.\n",
        "with open('mnist.tflite', 'wb') as f:\n",
        "  f.write(tflite_quantized_model)\n",
        "\n",
        "# Download the digit classification model\n",
        "from google.colab import files\n",
        "files.download('mnist.tflite')"
      ]
    },
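    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before redeploying, it can be worth checking that the converted model still classifies digits correctly. The cell below is a minimal sketch, not part of the original codelab: it assumes the variables `tflite_quantized_model`, `test_images` and `test_labels` from the cells above, and the sample size of 200 is an arbitrary choice. It runs the model with `tf.lite.Interpreter` one image at a time and reports accuracy on that sample."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Load the converted model from memory. `tflite_quantized_model`, `test_images`\n",
        "# and `test_labels` are assumed to come from the cells above.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_quantized_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_index = interpreter.get_input_details()[0]['index']\n",
        "output_index = interpreter.get_output_details()[0]['index']\n",
        "\n",
        "# Classify a small sample of unaugmented test images one at a time.\n",
        "sample_size = 200  # arbitrary; keeps the check fast\n",
        "correct = 0\n",
        "for image, label in zip(test_images[:sample_size], test_labels[:sample_size]):\n",
        "  # The interpreter expects a float32 batch of shape (1, 28, 28, 1).\n",
        "  input_tensor = np.expand_dims(image, axis=0).astype(np.float32)\n",
        "  interpreter.set_tensor(input_index, input_tensor)\n",
        "  interpreter.invoke()\n",
        "  prediction = np.argmax(interpreter.get_tensor(output_index)[0])\n",
        "  correct += int(prediction == label)\n",
        "\n",
        "print('TF Lite accuracy on %d test images: %.3f' % (sample_size, correct / sample_size))"
      ]
    },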
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CAp7mKIkrfjY"
      },
      "source": [
        "## Good job!\n",
        "This is the end of *Improve model accuracy with data augmentation* in the codelab **Build a handwritten digit classifier app with TensorFlow Lite**. You can repeat [step 3](https://codelabs.developers.google.com/codelabs/digit-classifier-tflite/#2) to redeploy the improved model to your Android app and see whether accuracy has improved."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "step7_improve_accuracy.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
